// Cuda runtime library
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "permUtility.h"
#include <cuda.h>
#include <curand.h>
#include "device_launch_parameters.h"
/* include MTGP host helper functions */
#include <curand_mtgp32_host.h>
/* include MTGP pre-computed parameter sets */
#include <curand_mtgp32dc_p_11213.h>

#include <ctime>

#ifdef LINUX
#include <inttypes.h>
#define __int64 int64_t
#endif

// Size of the (currently disabled) constant-memory genotype buffer, in 64-bit words.
// Parenthesized so the macro expands correctly inside larger expressions.
#define GENO_DATA_SIZE (79 * 3 * 32)
// Capacity of the constSta constant-memory array (one entry per SNP pair).
#define STA_SIZE 1024

// Evaluate a CUDA runtime (cudaError_t) or cuRAND (curandStatus_t) call; on failure print
// the call site and return EXIT_FAILURE from the enclosing function.
// cudaSuccess and CURAND_STATUS_SUCCESS are both 0, so the status is tested as an int rather
// than comparing a cudaError_t against an unrelated curandStatus_t enumerator.
#define CUDA_CALL(x) do { if((int)(x) != 0) { \
    printf("Error at %s:%d\n",__FILE__,__LINE__); \
    return EXIT_FAILURE;}} while(0)

// Global constant memory references
// (Disabled: a constant-memory copy of the genotype data.)
//__constant__ uint64 constGenoData[GENO_DATA_SIZE];
// Observed statistic of each SNP pair, read by permute_interaction_kernel as the
// threshold for pair i (constSta[i]). Filled by cuda_permute_interaction with
// cudaMemcpyToSymbol; only the first nSNP / 2 entries are written, so nSNP / 2
// must not exceed STA_SIZE.
__constant__ float constSta[STA_SIZE];

// Pair of 3-entry count tables, one for cases and one for controls.
// NOTE(review): presumably indexed by genotype (0/1/2) of a single SNP; the struct is
// not referenced by the code visible in this file — confirm against its users.
struct MarginalDistr {
	int marginalCase[3], marginalCtrl[3];
};

//// general helper functions

// Integer division rounded away from the truncated quotient whenever there is a
// remainder, i.e. ceil(a / b) for positive operands. b must be non-zero.
long long iDivUp(long long a, long long b) {
	long long quotient = a / b;
	if (a % b != 0)
		++quotient;
	return quotient;
}

// Report and abort on a pending CUDA runtime error.
// Reads (and clears) the last runtime error; if one is set, prints it to stderr
// prefixed with msg and terminates the process with EXIT_FAILURE.
void checkCUDAError(const char *msg) {
	const cudaError_t status = cudaGetLastError();
	if (status == cudaSuccess)
		return;
	fprintf(stderr, "Cuda error: %s: %s.\n", msg, cudaGetErrorString(status));
	exit(EXIT_FAILURE);
}

// Hamming weight (number of set bits) of a 64-bit word.
// Uses the CUDA __popcll population-count intrinsic instead of a hand-rolled
// SWAR bit-twiddling sequence; the result is identical (0..64).
inline __device__ int dev_count_bit(uint64 i) {
	return __popcll(i);
}

// Hamming weight of a 64-bit word using only shifts, masks and adds (no final
// multiply). Each step widens the fields that hold partial bit counts.
inline __device__ int dev_count_bit_slow_mult(uint64 x) {
	x = x - ((x >> 1) & 0x5555555555555555);						// 2-bit fields hold their own counts
	x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333);	// 4-bit fields
	x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f;						// 8-bit fields
	// Fold 16-, 32- and 64-bit halves into the lowest byte.
	for (int shift = 8; shift <= 32; shift *= 2) {
		x += x >> shift;
	}
	// The total is at most 64, so 7 bits suffice.
	return x & 0x7f;
}

// Number of label vectors (independent permutations) each block keeps in flight.
// Both permutation kernels and their host drivers size devLabels with this value.
#define numLabelGroup 128

/*
 * permute_marginal_kernel: permutation test for single-SNP (marginal) association.
 *
 * Launch layout (as set up by cuda_permute_marginal): one block per MTGP32 state
 * (devMTGPStates[blockIdx.x]), no dynamic shared memory. Within a block:
 *   - thread id owns label vectors id, id + blockDim.x, id + 2 * blockDim.x, ...
 *     (< numLabelGroup), so every vector is initialised and shuffled even when
 *     blockDim.x < numLabelGroup;
 *   - thread id < nSNP evaluates SNP id against all numLabelGroup vectors of its block.
 *
 * genoData   : 3 bit-packed rows per SNP (row 3 * snp + genotype), nLongIntSamples words each.
 * devLabels  : gridDim.x * numLabelGroup * nLongIntSamples words of scratch. Vector g of
 *              block b starts at word (b * numLabelGroup + g) * nLongIntSamples; bit s set
 *              means sample s is currently labelled a case.
 * devSta     : nSNP observed statistics, used as the per-SNP extremeness threshold.
 * permResult : one counter per launched thread (blockIdx.x * blockDim.x + threadIdx.x):
 *              how many permuted statistics of SNP threadIdx.x exceeded its threshold.
 *
 * NOTE(review): cuRAND's MTGP32 curand() is a block-wide call (every thread of the block
 * must make it, which the uniform loops below guarantee) and supports at most 256 threads
 * per block — keep blockDim.x <= 256.
 */
__global__ void permute_marginal_kernel(uint64 *genoData, int nSample, int nLongIntSamples, int nCase, int nSNP, curandStateMtgp32 *devMTGPStates, int numPermutation, int *permResult, float *devSta, uint64 *devLabels) {
	const uint64 mask1 = 0x0000000000000001;
	const uint64 masklow = 0x000000000000003f;	// bit position within a 64-bit word
	const int id = threadIdx.x;
	const int nThreads = blockDim.x;
	curandStateMtgp32 *RNGStates = &devMTGPStates[blockIdx.x];
	uint64 *blockLabels = &devLabels[(size_t)blockIdx.x * numLabelGroup * nLongIntSamples];
	// Rounds needed for the block to cover every label vector. The value is the same in all
	// threads, so the curand() calls inside the round loop stay uniform across the block.
	const int rounds = (numLabelGroup + nThreads - 1) / nThreads;

	// Start every owned vector from the unpermuted labelling: samples [0, nCase) are cases.
	for (int g = id; g < numLabelGroup; g += nThreads) {
		uint64 *labels = &blockLabels[g * nLongIntSamples];
		for (int w = 0; w < nLongIntSamples; w++)
			labels[w] = 0;
		for (int s = 0; s < nCase; s++)
			labels[s >> 6] |= mask1 << (s & masklow);
	}
	__syncthreads();

	// devSta holds nSNP entries but the block may be wider than nSNP; threads without a SNP
	// never compare against the threshold.
	const float threshold = (id < nSNP) ? devSta[id] : 0.0f;

	// The entropy of the case/control split does not change under label permutation.
	const float pCase = (float) nCase / nSample;
	const float MarginalEntropyY = -pCase * logf(pCase) - (1 - pCase) * logf(1 - pCase);

	int result = 0;
	for (int cnt = 0; cnt < numPermutation; cnt++) {
		// Fisher-Yates shuffle of every owned vector: j is drawn from [i, nSample), which yields
		// uniformly distributed permutations (drawing j from [0, nSample) would bias them).
		for (int i = 0; i < nSample; i++) {
			const uint64 iBit = mask1 << (i & masklow);
			for (int r = 0; r < rounds; r++) {
				// Keep the draw unsigned: abs() of a negative int (INT_MIN) can stay negative
				// and would produce a negative label index.
				const unsigned int raw = curand(RNGStates);
				const int g = r * nThreads + id;
				if (g < numLabelGroup) {
					uint64 *labels = &blockLabels[g * nLongIntSamples];
					const int j = i + (int)(raw % (unsigned int)(nSample - i));
					const uint64 jBit = mask1 << (j & masklow);
					const bool iSet = (labels[i >> 6] & iBit) != 0;
					const bool jSet = (labels[j >> 6] & jBit) != 0;
					// Swap the two label bits (a no-op when i == j).
					if (iSet)
						labels[j >> 6] |= jBit;
					else
						labels[j >> 6] &= ~jBit;
					if (jSet)
						labels[i >> 6] |= iBit;
					else
						labels[i >> 6] &= ~iBit;
				}
			}
			__syncthreads();
		}

		// Thread id computes the marginal association statistic of SNP id for every
		// permuted label vector of the block.
		if (id < nSNP) {
			for (int g = 0; g < numLabelGroup; g++) {
				const uint64 *labels = &blockLabels[g * nLongIntSamples];
				int GenoMarginalDistr[3][2];	// [genotype][0 = cases, 1 = controls]
				for (int geno = 0; geno < 3; geno++) {
					const uint64 *row = &genoData[(id * 3 + geno) * nLongIntSamples];
					int cntCase = 0, cntTotal = 0;
					for (int w = 0; w < nLongIntSamples; w++) {
						cntCase += dev_count_bit(row[w] & labels[w]);
						cntTotal += dev_count_bit(row[w]);
					}
					GenoMarginalDistr[geno][0] = cntCase;
					GenoMarginalDistr[geno][1] = cntTotal - cntCase;
				}

				float MarginalEntropySNP = 0, MarginalEntropySNP_Y = 0;
				for (int geno = 0; geno < 3; geno++) {
					const int nCaseCell = GenoMarginalDistr[geno][0];
					const int nCtrlCell = GenoMarginalDistr[geno][1];
					if (nCaseCell + nCtrlCell > 0) {
						const float p = (float)(nCaseCell + nCtrlCell) / nSample;	// P(X = geno)
						MarginalEntropySNP += -p * logf(p);
					}
					if (nCaseCell > 0) {
						const float p = (float) nCaseCell / nSample;
						MarginalEntropySNP_Y += -p * logf(p);
					}
					if (nCtrlCell > 0) {
						const float p = (float) nCtrlCell / nSample;
						MarginalEntropySNP_Y += -p * logf(p);
					}
				}

				const float MarginalAssociation = (-MarginalEntropySNP_Y + MarginalEntropySNP + MarginalEntropyY) * nSample * 2;
				if (MarginalAssociation > threshold)
					result++;
			}
		}
		// The next shuffle rewrites vectors other threads may still be reading.
		__syncthreads();
	}

	permResult[threadIdx.x + blockIdx.x * blockDim.x] = result;
}

/*
 * permute_interaction_kernel: permutation test for SNP pairs (SNP 2p, SNP 2p + 1).
 *
 * Launch layout (as set up by cuda_permute_interaction): blockDim.x must equal
 * numLabelGroup, a power of two that the tree reduction below relies on. There is one
 * MTGP32 state per block, and thread id owns label vector id of its block.
 *
 * Dynamic shared memory, in order:
 *   int    sharedThreadResult[numLabelGroup]
 *   int    sharedPermResult[nSNP / 2]
 *   int    padding                 (one int when nSNP / 2 is odd; keeps genoCache 8-byte aligned)
 *   uint64 genoCache[6 * nLongIntSamples]   (the 6 genotype rows of the current pair)
 *
 * genoData : 3 bit-packed rows per SNP; pair p occupies rows 6p .. 6p + 5.
 * devLabels: block b uses numLabelGroup * nLongIntSamples words starting at
 *            b * numLabelGroup * nLongIntSamples, interleaved so that word w of vector id is
 *            labels[w * numLabelGroup + id]; bit s set means sample s is labelled a case.
 * constSta : per-pair observed statistic used as the extremeness threshold.
 * permResult[blockIdx.x * (nSNP / 2) + p]: number of permuted statistics of pair p in this
 *            block (numPermutation rounds x numLabelGroup vectors) exceeding constSta[p].
 */
__global__ void permute_interaction_kernel(uint64 *genoData, int nSample, int nLongIntSamples, int nCase, int nSNP, curandStateMtgp32 *devMTGPStates, int numPermutation, int *permResult, uint64 *devLabels) {
	extern __shared__ int sharedMem[];
	const int nPair = nSNP / 2;
	uint64 *labels = &devLabels[blockIdx.x * numLabelGroup * nLongIntSamples];
	int *sharedThreadResult = sharedMem;					// [numLabelGroup]
	int *sharedPermResult = sharedMem + numLabelGroup;		// [nPair]
	uint64 *genoCache = (uint64 *)(sharedMem + numLabelGroup + nPair + nPair % 2);	// [6 * nLongIntSamples]
	curandStateMtgp32 *RNGStates = &devMTGPStates[blockIdx.x];
	const uint64 mask1 = 0x0000000000000001;
	const uint64 masklow = 0x000000000000003f;	// bit position within a 64-bit word
	const int id = threadIdx.x;

	// Unpermuted labelling of this thread's vector: samples [0, nCase) are cases.
	for (int w = 0; w < nLongIntSamples; w++)
		labels[w * numLabelGroup + id] = 0;
	for (int s = 0; s < nCase; s++)
		labels[(s >> 6) * numLabelGroup + id] |= mask1 << (s & masklow);
	sharedThreadResult[id] = 0;
	for (int p = id; p < nPair; p += numLabelGroup)
		sharedPermResult[p] = 0;
	__syncthreads();

	const int nCtrl = nSample - nCase;
	for (int cnt = 0; cnt < numPermutation; cnt++) {
		// Fisher-Yates shuffle of this thread's label vector. All threads draw on every step,
		// keeping the block-wide MTGP32 curand() call uniform.
		for (int i = 0; i < nSample; i++) {
			// Keep the draw unsigned: abs() of a negative int (INT_MIN) can stay negative and
			// would make j fall below i (even below 0).
			const unsigned int raw = curand(RNGStates);
			const int j = i + (int)(raw % (unsigned int)(nSample - i));
			const int localI = (i >> 6) * numLabelGroup + id;
			const int localJ = (j >> 6) * numLabelGroup + id;
			const uint64 iBit = mask1 << (i & masklow);
			const uint64 jBit = mask1 << (j & masklow);
			const uint64 iLabel = labels[localI] & iBit;
			const uint64 jLabel = labels[localJ] & jBit;
			if (jLabel == 0)
				labels[localI] &= ~iBit;
			else
				labels[localI] |= iBit;
			if (iLabel == 0)
				labels[localJ] &= ~jBit;
			else
				labels[localJ] |= jBit;
		}

		for (int pair = 0; pair < nPair; pair++) {
			const float threshold = constSta[pair];
			// Stage the 6 genotype rows of this pair in shared memory.
			for (int k = id; k < nLongIntSamples * 6; k += numLabelGroup)
				genoCache[k] = genoData[nLongIntSamples * 6 * pair + k];
			__syncthreads();

			// Marginal tables: [SNP of the pair][genotype][0 = cases, 1 = controls].
			int GenoMarginalDistr[2][3][2];
			for (int snp = 0; snp < 2; snp++) {
				for (int geno = 0; geno < 3; geno++) {
					const uint64 *row = &genoCache[(3 * snp + geno) * nLongIntSamples];
					int cntCase = 0, cntTotal = 0;
					for (int w = 0; w < nLongIntSamples; w++) {
						const uint64 gene_64 = row[w];
						cntCase += dev_count_bit(gene_64 & labels[w * numLabelGroup + id]);
						cntTotal += dev_count_bit(gene_64);
					}
					GenoMarginalDistr[snp][geno][0] = cntCase;
					GenoMarginalDistr[snp][geno][1] = cntTotal - cntCase;
				}
			}

			// Joint table: cases at [g1 * 3 + g2], controls at [9 + g1 * 3 + g2], where g1/g2
			// are the genotypes of the first/second SNP. The 2x2 block with g1, g2 < 2 is
			// counted directly.
			int GenoJointDistr[18];
			for (int g1 = 0; g1 < 2; g1++) {
				for (int g2 = 0; g2 < 2; g2++) {
					const uint64 *row1 = &genoCache[g1 * nLongIntSamples];
					const uint64 *row2 = &genoCache[(3 + g2) * nLongIntSamples];
					int cntCase = 0, cntTotal = 0;
					for (int w = 0; w < nLongIntSamples; w++) {
						const uint64 both = row1[w] & row2[w];
						cntCase += dev_count_bit(both & labels[w * numLabelGroup + id]);
						cntTotal += dev_count_bit(both);
					}
					GenoJointDistr[g1 * 3 + g2] = cntCase;
					GenoJointDistr[9 + g1 * 3 + g2] = cntTotal - cntCase;
				}
			}
			// The remaining cells follow from the marginals by subtraction. Cases:
			GenoJointDistr[2] = GenoMarginalDistr[0][0][0] - GenoJointDistr[0] - GenoJointDistr[1];
			GenoJointDistr[5] = GenoMarginalDistr[0][1][0] - GenoJointDistr[3] - GenoJointDistr[4];
			GenoJointDistr[6] = GenoMarginalDistr[1][0][0] - GenoJointDistr[0] - GenoJointDistr[3];
			GenoJointDistr[7] = GenoMarginalDistr[1][1][0] - GenoJointDistr[1] - GenoJointDistr[4];
			GenoJointDistr[8] = GenoMarginalDistr[1][2][0] - GenoJointDistr[2] - GenoJointDistr[5];
			// Controls:
			GenoJointDistr[11] = GenoMarginalDistr[0][0][1] - GenoJointDistr[9] - GenoJointDistr[10];
			GenoJointDistr[14] = GenoMarginalDistr[0][1][1] - GenoJointDistr[12] - GenoJointDistr[13];
			GenoJointDistr[15] = GenoMarginalDistr[1][0][1] - GenoJointDistr[9] - GenoJointDistr[12];
			GenoJointDistr[16] = GenoMarginalDistr[1][1][1] - GenoJointDistr[10] - GenoJointDistr[13];
			GenoJointDistr[17] = GenoMarginalDistr[1][2][1] - GenoJointDistr[11] - GenoJointDistr[14];

			// Statistic: 2 * sum over cells of n * log(n * nSample / (cell total * group size)).
			float ls_lb = 0;
			for (int g1 = 0; g1 < 3; g1++) {
				for (int g2 = 0; g2 < 3; g2++) {
					const int caseCell = GenoJointDistr[g1 + 3 * g2];
					const int ctrlCell = GenoJointDistr[g1 + 3 * g2 + 9];
					const int n_i_j_dot = caseCell + ctrlCell;
					float ls_lb_case = 0, ls_lb_ctrl = 0;
					if (caseCell > 0)
						ls_lb_case = caseCell * logf((float)(caseCell) * nSample / n_i_j_dot / nCase);
					if (ctrlCell > 0)
						ls_lb_ctrl = ctrlCell * logf((float)(ctrlCell) * nSample / n_i_j_dot / nCtrl);
					ls_lb += ls_lb_case + ls_lb_ctrl;
				}
			}
			ls_lb *= 2;

			// Count, across the block's label vectors, how many exceed the threshold
			// (shared-memory tree reduction; requires blockDim.x == numLabelGroup).
			sharedThreadResult[id] = (ls_lb > threshold) ? 1 : 0;
			__syncthreads();
			for (int stride = numLabelGroup / 2; stride > 0; stride /= 2) {
				if (id < stride)
					sharedThreadResult[id] += sharedThreadResult[id + stride];
				__syncthreads();
			}
			if (id == 0)
				sharedPermResult[pair] += sharedThreadResult[0];
			// genoCache and sharedThreadResult are rewritten by the next pair.
			__syncthreads();
		}
	}

	for (int p = id; p < nPair; p += numLabelGroup)
		permResult[blockIdx.x * nPair + p] = sharedPermResult[p];
}

// Host driver for the marginal (single-SNP) permutation test.
//
// Uploads the genotype data and the observed per-SNP statistics, seeds one MTGP32 state per
// block, then repeatedly launches permute_marginal_kernel (up to 999999 launches) and appends
// the running count of extreme permutations per SNP to "<outputPrefix>_output.txt" and stdout.
// outputPrefix is only read; the file name is built in a private buffer.
//
// Returns 1 when the loop completes, -1 on a runtime error, and EXIT_FAILURE when a
// CUDA_CALL-checked setup call fails. position and p_value are not used here.
extern "C" int cuda_permute_marginal(uint64 *hostGenoData, int nLongIntSample, int nSample, int nCase, int nSNP, int *position, double *statistics, double *p_value, char *outputPrefix, struct KernelParams kernelParams) {
	int threadNum = nSNP;
	int blockNum = 2;
	int numPermutation = 1000;
	if(kernelParams.numBlock > 0)
		blockNum = kernelParams.numBlock;
	if(kernelParams.numThread > 0)
		threadNum = kernelParams.numThread;
	if(kernelParams.numPermutation > 0)
		numPermutation = kernelParams.numPermutation;
	dim3 threads(threadNum, 1, 1);
	dim3 grid(blockNum, 1, 1);
	cudaError_t err;
	int status = 1;

	// Genotype data: 3 bit-packed rows of nLongIntSample words per SNP.
	// size_t keeps the byte count from overflowing int on large inputs.
	uint64 *devGenoData;
	const size_t genoDataSize = (size_t)3 * nSNP * nLongIntSample * sizeof(uint64);
	CUDA_CALL(cudaMalloc((void**)&devGenoData, genoDataSize));
	CUDA_CALL(cudaMemcpy(devGenoData, hostGenoData, genoDataSize, cudaMemcpyHostToDevice));

	// Observed statistics, narrowed to float for the kernel.
	float *devSta;
	CUDA_CALL(cudaMalloc((void**)&devSta, nSNP * sizeof(float)));
	float *hostSta = (float *)malloc(nSNP * sizeof(float));
	for(int i = 0; i < nSNP; i++) {
		hostSta[i] = (float)statistics[i];
	}
	err = cudaMemcpy(devSta, hostSta, nSNP * sizeof(float), cudaMemcpyHostToDevice);
	free(hostSta);
	CUDA_CALL(err);

	// Scratch space: numLabelGroup label vectors per block.
	uint64 *devLabels;
	CUDA_CALL(cudaMalloc((void **)&devLabels, (size_t)blockNum * nLongIntSample * numLabelGroup * sizeof(uint64)));

	// One MTGP32 state per block plus the shared generator parameters.
	curandStateMtgp32 *devMTGPStates;
	CUDA_CALL(cudaMalloc((void **)&devMTGPStates, blockNum * sizeof(curandStateMtgp32)));
	mtgp32_kernel_params *devKernelParams;
	CUDA_CALL(cudaMalloc((void**)&devKernelParams, sizeof(mtgp32_kernel_params)));
	CUDA_CALL(curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, devKernelParams));

	// 64-bit seed assembled from two rand() draws seeded with the wall clock.
	srand(time(0));
	unsigned long long most = rand(), least = rand();
	unsigned long long seed = (most << 32ULL) + least;
	printf("seed = %llu\n", seed);
	CUDA_CALL(curandMakeMTGP32KernelState(devMTGPStates, mtgp32dc_params_fast_11213, devKernelParams, blockNum, seed));

	// One extreme-count per launched thread.
	const int totalThreadNum = blockNum * threadNum;
	int *devPermResult;
	CUDA_CALL(cudaMalloc((void **)&devPermResult, totalThreadNum * sizeof(int)));
	int *hostPermResult = (int *)malloc(totalThreadNum * sizeof(int));
	int *hostExtreme = (int *)malloc(nSNP * sizeof(int));
	memset(hostPermResult, 0, totalThreadNum * sizeof(int));
	memset(hostExtreme, 0, nSNP * sizeof(int));
	CUDA_CALL(cudaMemcpy(devPermResult, hostPermResult, totalThreadNum * sizeof(int),  cudaMemcpyHostToDevice));

	cudaEvent_t start, stop;
	float elapsed;
	printf ("Start to permute. threadNum: %d, blockNum: %d\n", threadNum, blockNum);
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	cudaEventRecord(start, 0);

	// Build "<outputPrefix>_output.txt" in a private buffer: strcat() onto outputPrefix would
	// modify the caller's string and could write past the end of its buffer.
	const char *suffix = "_output.txt";
	const size_t nameLen = strlen(outputPrefix) + strlen(suffix) + 1;
	char *outputFilename = (char *)malloc(nameLen);
	FILE *output_fp = NULL;
	if(outputFilename != NULL) {
		snprintf(outputFilename, nameLen, "%s%s", outputPrefix, suffix);
		output_fp = fopen(outputFilename, "a+");
	}
	if(output_fp == NULL) {
		fprintf(stderr, "can't open output file %s\n", outputFilename != NULL ? outputFilename : outputPrefix);
		status = -1;
	} else {
		fprintf(output_fp, "Start to permute. threadNum: %d, blockNum: %d\n", threadNum, blockNum);
		fflush(output_fp);

		int device;
		cudaDeviceProp deviceProp;
		cudaGetDevice(&device);
		cudaGetDeviceProperties(&deviceProp, device);
		printf("Using device %d: %s \n", device, deviceProp.name);

		// The kernel stores thread t of block b at b * threadNum + t, and thread t evaluates
		// SNP t, so only the first min(nSNP, threadNum) counters of each block are SNP results.
		const int nEvaluated = (nSNP < threadNum) ? nSNP : threadNum;
		// Widen before multiplying so the product cannot overflow int.
		const long long perm_in_loop = (long long)numPermutation * blockNum;
		long long totalPermutation = 0;

		for(int cnt = 1; cnt < 1000000; cnt++) {
			permute_marginal_kernel<<<grid, threads>>>(devGenoData, nSample, nLongIntSample, nCase, nSNP, devMTGPStates, numPermutation, devPermResult, devSta, devLabels);
			err = cudaGetLastError();
			if (err != cudaSuccess) {
				// A launch-configuration error would repeat on every iteration; stop instead.
				printf("Error: %s\n", cudaGetErrorString(err));
				status = -1;
				break;
			}
			// Blocking copy: also surfaces asynchronous faults from the kernel.
			err = cudaMemcpy(hostPermResult, devPermResult, totalThreadNum * sizeof(int),  cudaMemcpyDeviceToHost);
			if (err != cudaSuccess) {
				printf("Error: %s\n", cudaGetErrorString(err));
				status = -1;
				break;
			}
			for(int j = 0; j < blockNum; j++) {
				for(int i = 0; i < nEvaluated; i++) {
					hostExtreme[i] += hostPermResult[threadNum * j + i];
				}
			}

			cudaEventRecord(stop, 0);
			cudaEventSynchronize(stop);
			cudaEventElapsedTime(&elapsed, start, stop);
			int minute = elapsed / 1000 / 60;
			elapsed = elapsed / 1000 - minute * 60;
			int hour = minute / 60;
			minute = minute % 60;
			int day = hour / 24;
			hour = hour % 24;
			printf("Time elapsed: %d d %d h %d m %f s\n", day, hour, minute, elapsed);
			fprintf(output_fp, "Time elapsed: %d d %d h %d m %f s\n", day, hour, minute, elapsed);

			totalPermutation = cnt * perm_in_loop * numLabelGroup;
			const int firstExtreme = (nSNP > 0) ? hostExtreme[0] : 0;
			printf("Permutation: %lld, Extreme times: %d\n", totalPermutation, firstExtreme);
			fprintf(output_fp, "Permutation: %lld, Extreme times: %d\n", totalPermutation, firstExtreme);
			for(int i = 0; i < nSNP; i++) {
				if(hostExtreme[i] > 0) {
					printf("SNP %d extreme: %d\n", i, hostExtreme[i]);
					fprintf(output_fp, "SNP %d extreme: %d\n", i, hostExtreme[i]);
				}
			}
			fflush(output_fp);
		}
		fclose(output_fp);
	}

	free(outputFilename);
	free(hostPermResult);
	free(hostExtreme);
	cudaEventDestroy(start);
	cudaEventDestroy(stop);
	cudaFree(devPermResult);
	cudaFree(devLabels);
	cudaFree(devSta);
	cudaFree(devGenoData);
	cudaFree(devMTGPStates);
	cudaFree(devKernelParams);
	return status;
}

// Host driver for the pairwise interaction permutation test.
//
// SNPs are tested in consecutive pairs (2p, 2p + 1). statistics[p] is the observed statistic of
// pair p and is copied into constSta. Each launch of permute_interaction_kernel performs
// numPermKernel shuffles of numLabelGroup label vectors in every block. Launches repeat until at
// least numPermutation permutations are done (kernelParams.numPermutation, default 1e9).
// Running counts of extreme permutations per pair are appended to
// "<outputPrefix>_output.txt" and, when kernelParams.output2screen == 1, echoed to stdout.
// outputPrefix is only read.
//
// Returns 0 on success, -1 on a runtime or input error, and EXIT_FAILURE when a
// CUDA_CALL-checked setup call fails. position and p_value are not used here.
extern "C" int cuda_permute_interaction(uint64 *hostGenoData, int nLongIntSample, int nSample, int nCase, int nSNP, int *position, double *statistics, double *p_value, char *outputPrefix, struct KernelParams kernelParams) {
	// The kernel's label layout and reduction assume exactly numLabelGroup threads per block.
	const int threadNum = numLabelGroup;
	const int nPair = nSNP / 2;
	int blockNum = 128;
	const int numPermKernel = 10000;
	int numPermutation = 1e9;
	if(kernelParams.numBlock > 0)
		blockNum = kernelParams.numBlock;
	if(kernelParams.numThread > 0 && kernelParams.numThread != threadNum)
		fprintf(stderr, "Warning: the interaction kernel needs %d threads per block; ignoring numThread = %d\n", threadNum, kernelParams.numThread);
	if(kernelParams.numPermutation > 0)
		numPermutation = kernelParams.numPermutation;
	if(nPair > STA_SIZE) {
		fprintf(stderr, "Too many SNP pairs: %d (constant memory holds at most %d)\n", nPair, STA_SIZE);
		return -1;
	}
	// Widen before multiplying: numPermKernel * blockNum * numLabelGroup can exceed INT_MAX.
	const long long perm_in_loop = (long long)numPermKernel * blockNum * numLabelGroup;
	// Just enough launches to reach numPermutation (ceiling division).
	const long long numLoop = iDivUp(numPermutation, perm_in_loop);
	dim3 threads(threadNum, 1, 1);
	dim3 grid(blockNum, 1, 1);
	cudaError_t err;
	int status = 0;

	// Genotype data: 3 bit-packed rows of nLongIntSample words per SNP.
	uint64 *devGenoData;
	const size_t genoDataSize = (size_t)3 * nSNP * nLongIntSample * sizeof(uint64);
	CUDA_CALL(cudaMalloc((void**)&devGenoData, genoDataSize));
	CUDA_CALL(cudaMemcpy(devGenoData, hostGenoData, genoDataSize, cudaMemcpyHostToDevice));

	// Per-pair thresholds go to constant memory (constSta); no global copy is needed.
	float *hostSta = (float *)malloc(nPair * sizeof(float));
	for(int i = 0; i < nPair; i++) {
		hostSta[i] = (float)statistics[i];
	}
	err = cudaMemcpyToSymbol(constSta, hostSta, nPair * sizeof(float));
	free(hostSta);
	CUDA_CALL(err);

	// Scratch space: numLabelGroup label vectors per block.
	uint64 *devLabels;
	CUDA_CALL(cudaMalloc((void **)&devLabels, (size_t)blockNum * nLongIntSample * numLabelGroup * sizeof(uint64)));

	// One MTGP32 state per block plus the shared generator parameters.
	curandStateMtgp32 *devMTGPStates;
	CUDA_CALL(cudaMalloc((void **)&devMTGPStates, blockNum * sizeof(curandStateMtgp32)));
	mtgp32_kernel_params *devKernelParams;
	CUDA_CALL(cudaMalloc((void**)&devKernelParams, sizeof(mtgp32_kernel_params)));
	CUDA_CALL(curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, devKernelParams));

	// 64-bit seed assembled from two rand() draws seeded with the wall clock.
	srand(time(0));
	unsigned long long most = rand(), least = rand();
	unsigned long long seed = (most << 32ULL) + least;
	printf("seed = %llu\n", seed);
	CUDA_CALL(curandMakeMTGP32KernelState(devMTGPStates, mtgp32dc_params_fast_11213, devKernelParams, blockNum, seed));

	// Per-block, per-pair extreme counts; the whole device buffer is zero-initialised.
	const size_t permResultBytes = (size_t)blockNum * nPair * sizeof(int);
	int *devPermResult;
	CUDA_CALL(cudaMalloc((void **)&devPermResult, permResultBytes));
	int *hostPermResult = (int *)malloc(permResultBytes);
	int *hostExtreme = (int *)malloc(nPair * sizeof(int));
	memset(hostPermResult, 0, permResultBytes);
	memset(hostExtreme, 0, nPair * sizeof(int));
	CUDA_CALL(cudaMemcpy(devPermResult, hostPermResult, permResultBytes,  cudaMemcpyHostToDevice));

	cudaEvent_t start, stop;
	float elapsed;
	printf ("Start to permute. threadNum: %d, blockNum: %d\n", threadNum, blockNum);
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	cudaEventRecord(start, 0);

	// Build "<outputPrefix>_output.txt" in a private buffer: strcat() onto outputPrefix would
	// modify the caller's string and could write past the end of its buffer.
	const char *suffix = "_output.txt";
	const size_t nameLen = strlen(outputPrefix) + strlen(suffix) + 1;
	char *outputFilename = (char *)malloc(nameLen);
	FILE *output_fp = NULL;
	if(outputFilename != NULL) {
		snprintf(outputFilename, nameLen, "%s%s", outputPrefix, suffix);
		output_fp = fopen(outputFilename, "a+");
	}
	if(output_fp == NULL) {
		fprintf(stderr, "can't open output file %s\n", outputFilename != NULL ? outputFilename : outputPrefix);
		status = -1;
	} else {
		fprintf(output_fp, "Start to permute. threadNum: %d, blockNum: %d\n", threadNum, blockNum);
		fflush(output_fp);

		int device;
		cudaDeviceProp deviceProp;
		cudaGetDevice(&device);
		cudaGetDeviceProperties(&deviceProp, device);
		printf("Using device %d: %s \n", device, deviceProp.name);
		// Use 48k L1 and 16k shared memory.
		cudaFuncSetCacheConfig(permute_interaction_kernel, cudaFuncCachePreferL1);

		// Must match the kernel's layout: sharedThreadResult[numLabelGroup], sharedPermResult[nPair],
		// one padding int when nPair is odd (8-byte alignment of genoCache), then
		// genoCache[6 * nLongIntSample].
		size_t sharedSize = numLabelGroup * sizeof(int) + (nPair + nPair % 2) * sizeof(int) + 6 * nLongIntSample * sizeof(uint64);
		if(sharedSize < 5368)
			sharedSize = 5368;

		long long totalPermutation = 0;
		for(long long cnt = 1; cnt <= numLoop; cnt++) {
			permute_interaction_kernel<<<grid, threads, sharedSize>>>(devGenoData, nSample, nLongIntSample, nCase, nSNP, devMTGPStates, numPermKernel, devPermResult, devLabels);
			err = cudaGetLastError();
			if (err != cudaSuccess) {
				printf("Error: %s\n", cudaGetErrorString(err));
				status = -1;
				break;
			}
			// Blocking copy: also surfaces asynchronous faults from the kernel.
			err = cudaMemcpy(hostPermResult, devPermResult, permResultBytes,  cudaMemcpyDeviceToHost);
			if (err != cudaSuccess) {
				printf("Error: %s\n", cudaGetErrorString(err));
				status = -1;
				break;
			}
			for(int j = 0; j < blockNum; j++) {
				for(int i = 0; i < nPair; i++) {
					hostExtreme[i] += hostPermResult[nPair * j + i];
				}
			}

			cudaEventRecord(stop, 0);
			cudaEventSynchronize(stop);
			cudaEventElapsedTime(&elapsed, start, stop);
			int minute = elapsed / 1000 / 60;
			elapsed = elapsed / 1000 - minute * 60;
			int hour = minute / 60;
			minute = minute % 60;
			int day = hour / 24;
			hour = hour % 24;
			fprintf(output_fp, "Time elapsed: %d d %d h %d m %f s\n", day, hour, minute, elapsed);
			totalPermutation = cnt * perm_in_loop;
			fprintf(output_fp, "Permutation: %lld\n", totalPermutation);
			if(kernelParams.output2screen == 1) {
				printf("Time elapsed: %d d %d h %d m %f s\n", day, hour, minute, elapsed);
				printf("Permutation: %lld\n", totalPermutation);
			}
			for(int i = 0; i < nPair; i++) {
				if(hostExtreme[i] > 0) {
					if(kernelParams.output2screen == 1) {
						printf("SNP %d extreme: %d\n", i, hostExtreme[i]);
					}
					fprintf(output_fp, "SNP %d extreme: %d\n", i, hostExtreme[i]);
				}
			}
			fflush(output_fp);
		}
		fclose(output_fp);
	}

	free(outputFilename);
	free(hostPermResult);
	free(hostExtreme);
	cudaEventDestroy(start);
	cudaEventDestroy(stop);
	cudaFree(devPermResult);
	cudaFree(devLabels);
	cudaFree(devGenoData);
	cudaFree(devKernelParams);
	cudaFree(devMTGPStates);
	return status;
}
